From FHIR Data to Balanced Neural Network BRCA Subtype Predictions¶

  1. Loading TCGA-BRCA cohort's Fast Healthcare Interoperability Resources (FHIR) Phenotype Data with GRaph Integration Platform (GRIP)
  2. Dimensionality Reduction with Autoencoder
  3. Classify Subtypes with Multimodel Nueral Network (with model evaluation and prediction)
  4. Improve Model to a Balanced Multimodal Neural Network Training (with model evaluation and prediction)
  5. Pathway enrichment analysis on Subtype's Model Features
  6. Applying GDC Trained Model to METABRIC for Cross-Dataset Subtype Prediction (short demo)
  7. Conclusion
    • Recap:
    • FHIR schema and grip enabled loading well-structured data.
    • Autoencoder reduced data dimensionality while preserving key features.
    • A balanced multimodal neural network, trained with SMOTETomek, improved classification accuracy for all subtypes.
    • Highlight:
    • Domain-specific expertise played a critical role in subtype classification.
    • In precision medicine, classifying cancer subtypes is essential for actionable insights and better outcomes.
    • This approach bridges biological data and AI, empowering robust subtype prediction workflows and empowering better treatment decisions in real-world scenarios.
In [20]:
import os
import sys
import json
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "notebook_connected"

from sklearn.preprocessing import StandardScaler, QuantileTransformer, MinMaxScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import umap
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import roc_curve, roc_auc_score


from imblearn.combine import SMOTETomek

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.utils import to_categorical

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.preprocessing import label_binarize

import torch
import torch.nn as nn
import torch.optim as optim

from tensorflow.keras.callbacks import ReduceLROnPlateau

seed = 1234
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
tf.random.set_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
os.environ["TF_DETERMINISTIC_OPS"] = "1"

import warnings
import logging

warnings.filterwarnings('ignore')
os.environ["KMP_WARNINGS"] = "0"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
sys.stderr = open(os.devnull, 'w')

tf.get_logger().setLevel(logging.ERROR)
tf.autograph.set_verbosity(0)

warnings.filterwarnings("ignore", category=UserWarning, module="tensorflow")
warnings.filterwarnings("ignore", category=FutureWarning)

1. GRIP query on TCGA-BRCA FHIR Data¶

  • Extracting breast cancer subtypes and associated patients from TCGA-BRCA FHIR transformed data
In [2]:
from IPython.display import Image
Image("../img/TCGA-BRCA-FHIRized.png", width=850) 
Out[2]:
No description has been provided for this image
In [3]:
import gripql

conn = gripql.Connection("http://localhost:8201")
G = conn.graph("gdc_brca")
G.query().V().hasLabel("Patient").execute()[:1] # look at structure.
Out[3]:
[{'gid': '28cd8c4a-2a0a-5090-8ee1-52cc2a9df946',
  'label': 'Patient',
  'data': {'deceasedBoolean': False,
   'extension': [{'url': 'http://hl7.org/fhir/us/core/StructureDefinition/us-core-birthsex',
     'valueCode': 'F'},
    {'url': 'http://hl7.org/fhir/us/core/StructureDefinition/us-core-race',
     'valueString': 'white'},
    {'url': 'http://hl7.org/fhir/us/core/StructureDefinition/us-core-ethnicity',
     'valueString': 'not hispanic or latino'},
    {'url': 'http://hl7.org/fhir/SearchParameter/patient-extensions-Patient-age',
     'valueQuantity': {'value': 58}}],
   'gender': 'female',
   'id': '28cd8c4a-2a0a-5090-8ee1-52cc2a9df946',
   'identifier': [{'system': 'https://gdc.cancer.gov/case_submitter_id',
     'use': 'secondary',
     'value': 'TCGA-BH-A0W3'},
    {'system': 'https://gdc.cancer.gov/case_id',
     'use': 'official',
     'value': '3c612e12-6de8-44fa-a095-805c45474821'}],
   'resourceType': 'Patient'}}]
In [4]:
subtypes = G.query().V().hasLabel("Observation").as_("observation").unwind("$.code").unwind("$.code.coding").has(gripql.eq("$.code.coding.code", "NCIT_C185941")).execute()
subtypes_patients = G.query().V().hasLabel("Observation").as_("observation").unwind("$.code").unwind("$.code.coding").has(gripql.eq("$.code.coding.code", "NCIT_C185941")).out("focus_Patient").execute()
In [5]:
def expand_metadata(patient_data) -> dict:
    exts = patient_data.get('extension', [])
    race = next((ext.get('valueString') for ext in exts if 'us-core-race' in ext.get('url', '')), None)
    ethnicity = next((ext.get('valueString') for ext in exts if 'us-core-ethnicity' in ext.get('url', '')), None)
    age = next((ext.get('valueQuantity', {}).get('value') for ext in exts if 'Patient-age' in ext.get('url', '')), None)
    
    idents = patient_data.get('identifier', [])
    secondary_identifier = next((iden.get('value') for iden in idents if iden.get('use') == 'secondary'), None)
    official_identifier = next((iden.get('value') for iden in idents if iden.get('use') == 'official'), None)
    
    return {
        'id': patient_data.get('id'),
        'official_identifier': official_identifier,
        'secondary_identifier': secondary_identifier,
        'gender': patient_data.get('gender'),
        'race': race,
        'ethnicity': ethnicity,
        'age': age,
        'deceased': patient_data.get('deceasedBoolean')
    }

# create a data frame for phenotypic data
df_pheno = pd.DataFrame([expand_metadata(dat['data']) for dat in subtypes_patients])
df_pheno['age'] = pd.to_numeric(df_pheno['age'], errors='coerce').fillna(0).astype(np.int64)

l = []
for item in subtypes:
    if 'valueString' in item['data'].keys():
        l.append(item['data']['valueString'])
df_pheno['subtype'] = l
# df_pheno[["secondary_identifier", "subtype", "gender", "race",	"ethnicity", "age", "deceased"]]
df_pheno[["secondary_identifier", "subtype"]] # this demo's focus is on subtypes
Out[5]:
secondary_identifier subtype
0 TCGA-3C-AAAU LumA
1 TCGA-3C-AALI Her2
2 TCGA-3C-AALJ LumB
3 TCGA-3C-AALK LumA
4 TCGA-4H-AAAK LumA
... ... ...
990 TCGA-WT-AB44 LumA
991 TCGA-XX-A899 LumA
992 TCGA-XX-A89A LumA
993 TCGA-Z7-A8R5 LumA
994 TCGA-Z7-A8R6 LumB

995 rows × 2 columns

In [6]:
print("Loading normalized mRNA expressions on TCGA-BRCA GDC samples...")
df_gdc = pd.read_csv("../data/TMP_20230209/BRCA_v12_20210228.tsv", sep="\t")
df_gdc = df_gdc[(~df_gdc.iloc[:, 2:].isna().all(axis=1)) & (~(df_gdc.iloc[:, 2:] == 0).all(axis=1))]
expression_cols = [col for col in df_gdc.columns if 'GEXP' in col]
df_gdc = df_gdc[['BRCA', 'Labels'] + expression_cols]
df_gdc.columns = [col.split(":")[3] if ':' in col else col for col in df_gdc.columns]

if df_gdc.columns.duplicated().any():
    print("Duplicate columns detected in GDC data. Removing duplicates...")
    df_gdc = df_gdc.loc[:, ~df_gdc.columns.duplicated()]

print("Joining grip query result on BRCA subtype's with normalized mRNA expressions...")
df_pheno.rename(columns={"secondary_identifier": "BRCA"}, inplace=True)
df_gdc = df_gdc.merge(df_pheno[['BRCA', 'subtype']], on='BRCA', how='left')
df_gdc.drop('Labels', axis=1, inplace=True)
df_gdc.rename(columns={"subtype": "Labels"}, inplace=True)
df_gdc.index = df_gdc.BRCA
demo_order = ['Labels'] + list(df_gdc.columns[3:10])
Loading normalized mRNA expressions on TCGA-BRCA GDC samples...
Duplicate columns detected in GDC data. Removing duplicates...
Joining grip query result on BRCA subtype's with normalized mRNA expressions...
In [7]:
subtype_colors = {
    "LumA": "#44AA99",  
    "LumB": "#FFC107",  
    "Basal": "#AA4499", 
    "Her2": "#449AE4"   
}

print("EDA - Evaluating data distribution...")
count_per_label = df_gdc['Labels'].value_counts()
print(count_per_label)
plt.figure(figsize=(6, 4))
sns.barplot(x=count_per_label.index, y=count_per_label.values, palette=[subtype_colors[label] for label in count_per_label.index])
plt.title("Subtype Class Count Distribution", fontsize=12)
plt.xlabel("Subtype", fontsize=10), plt.ylabel("Count", fontsize=10)
plt.xticks(rotation=45)
plt.grid(axis='y', linestyle='--', alpha=0.6)
plt.tight_layout()
plt.show()

# PCA on TCGA-BRCA normalized mRNA GDC Data
gene_expression_cols = df_gdc.drop(columns=["Labels", "BRCA"]).columns  
df_gdc_scaled = df_gdc[['Labels'] + list(gene_expression_cols)].copy()

scaler = StandardScaler()
df_gdc_scaled[gene_expression_cols] = scaler.fit_transform(df_gdc_scaled[gene_expression_cols])
pca = PCA(n_components=3)
df_gdc_scaled[['PCA1', 'PCA2', 'PCA3']] = pca.fit_transform(df_gdc_scaled[gene_expression_cols])

# 2D PCA 
plt.figure(figsize=(8, 6))
sns.scatterplot(data=df_gdc_scaled, x='PCA1', y='PCA2', hue='Labels', palette=subtype_colors, s=100, alpha=0.8)
plt.title("TCGA-BRCA scRNA-Seq mRNA Gene Expression PCA")
plt.xlabel("Principal Component 1"), plt.ylabel("Principal Component 2")
plt.legend(title='Labels', loc='upper right', fontsize=10)
plt.grid(True)
plt.tight_layout()
plt.show()

# # 3D PCA w plotly
# fig = px.scatter_3d(df_gdc_scaled, x='PCA1', y='PCA2', z='PCA3', color='Labels',
#                      title="TCGA-BRCA scRNA-Seq 3D PCA Visualization",
#                      labels={"PCA1": "Principal Component 1", "PCA2": "Principal Component 2", "PCA3": "Principal Component 3"},
#                      opacity=0.8, width=1200, height=900, color_discrete_map=subtype_colors)

# fig.update_traces(marker=dict(size=8))
# fig.show()
EDA - Evaluating data distribution...
Labels
LumA     535
LumB     205
Basal    175
Her2      80
Name: count, dtype: int64
No description has been provided for this image
No description has been provided for this image
  • Luminal A and Luminal B are both classified as hormone receptor-positive (HR+), meaning they express estrogen and progesterone receptors.
  • Luminal A is usually HER2-negative while Luminal B may be HER2-positive, which can impact treatment options.
  • Basal tumors lack estrogen receptors, making them "triple negative" (ER-, PR-, HER2-) while Luminal cancers are typically hormone receptor positive (ER+) and have a better outlook with hormone therapy options.

2. Dimensionality Reduction with Autoencoder¶

  • Reduce feature dimensionality while preserving key molecular characteristics.
  • Autoencoders compress the high-dimensional data into a more efficient representation.
In [9]:
class Autoencoder(nn.Module):
    def __init__(self, input_dim, encoding_dim=16, deep=False):
        super().__init__()
        enc_layers = [nn.Linear(input_dim, 512), nn.ReLU(), nn.Linear(512, 128), nn.ReLU()] if deep else [nn.Linear(input_dim, 128), nn.ReLU()]
        dec_layers = [nn.Linear(128, 512), nn.ReLU(), nn.Linear(512, input_dim)] if deep else [nn.Linear(128, input_dim)]
        self.encoder = nn.Sequential(*enc_layers, nn.Linear(128, encoding_dim), nn.ReLU())
        self.decoder = nn.Sequential(nn.Linear(encoding_dim, 128), nn.ReLU(), *dec_layers, nn.Sigmoid())

    def forward(self, x):
        return self.decoder(self.encoder(x))

        
def generate_embeddings(X, input_dim, encoding_dim=16, epochs=50, lr=0.001):
    model = Autoencoder(input_dim, encoding_dim, deep=False)
    criterion, optimizer = nn.MSELoss(), optim.Adam(model.parameters(), lr=lr)
    X_tensor = torch.tensor(X.values, dtype=torch.float32)

    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(X_tensor), X_tensor)
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        return model.encoder(X_tensor).numpy()

gene_expression_cols = df_gdc_scaled.drop(columns=[col for col in ["Labels", "BRCA"] if col in df_gdc_scaled.columns]).columns
input_dim = len(gene_expression_cols)
embeddings = generate_embeddings(df_gdc_scaled[gene_expression_cols], input_dim)
embedding_cols = [f"Dim{i+1}" for i in range(16)]
df_gdc_scaled[embedding_cols] = embeddings
df_gdc_scaled["UMAP1"], df_gdc_scaled["UMAP2"] = umap.UMAP(n_components=2, random_state=42).fit_transform(embeddings).T

plt.figure(figsize=(8, 6))
sns.scatterplot(data=df_gdc_scaled, x="UMAP1", y="UMAP2", hue="Labels", palette=subtype_colors, s=60, alpha=0.8)
plt.title("UMAP Projection of Autoencoder Embeddings")
plt.xlabel("UMAP Dimension 1")
plt.ylabel("UMAP Dimension 2")
plt.legend(title="Subtype")
plt.grid(True)
plt.tight_layout()
plt.show()
No description has been provided for this image

3. Classify Subtypes with Multimodel Nueral Network¶

  • When classifying data into multiple subtype categories, only one subtype can be assigned to each patient/specimen data point.
    • softmax formula which is an activation for the multi-class scenario
    • loss='categorical_crossentropy': produces a one-hot array containing the probable match for each category
  • If one class is overrepresented or underrepresented, it might skew the model’s performance. Demo class imbalance in cancer subtype data (ex. Her2 is underrepresented).
In [10]:
class_maps = {0:"LumA",1:"LumB",2:"Basal",3:"Her2"}
X = df_gdc_scaled.drop(columns=[col for col in ["Labels", "BRCA"] if col in df_gdc_scaled.columns]).apply(pd.to_numeric, errors='coerce').fillna(0)
y = df_gdc['Labels'].astype('category').cat.codes
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=42)
y_train,y_test=to_categorical(y_train,num_classes=len(class_maps)),to_categorical(y_test,num_classes=len(class_maps))

@tf.function(reduce_retracing=True)
def f1_score(y_true, y_pred):
    """f1-score as a custom metric in Keras"""
    y_pred = K.round(y_pred) 
    tp = K.sum(K.cast(y_true * y_pred, 'float'), axis=0)
    precision = tp / (K.sum(K.cast(y_pred, 'float'), axis=0) + K.epsilon())
    recall = tp / (K.sum(K.cast(y_true, 'float'), axis=0) + K.epsilon())
    f1 = 2 * (precision * recall) / (precision + recall + K.epsilon())
    return K.mean(f1) 


def nn_multiclass_model(optimizer='adam', dropout_rate=0.2, hidden_units=64):
    model = Sequential([
        Dense(hidden_units, activation='relu', input_shape=(X_train.shape[1],)), 
        Dropout(dropout_rate),
        Dense(hidden_units, activation='relu'),
        Dense(len(class_maps), activation='softmax')
    ])
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy', f1_score], run_eagerly=False)
    return model

model_nn = nn_multiclass_model()
history = model_nn.fit(X_train, y_train, batch_size=32, epochs=20, verbose=1, validation_split=0.2)
test_loss, test_accuracy, test_f1 = model_nn.evaluate(X_test, y_test, verbose=0)
print(f"Test Accuracy: {test_accuracy:.3f}, Test F1-Score: {test_f1:.3f}")

y_pred_proba=model_nn.predict(X_test)
fpr,tpr,roc_auc={}, {},{}
for i in range(len(class_maps)):
    fpr[i],tpr[i],_=roc_curve(y_test[:,i],y_pred_proba[:,i])
    roc_auc[i]=roc_auc_score(y_test[:,i],y_pred_proba[:,i])

plt.figure(figsize=(8,6))
for i, label in class_maps.items():
    plt.plot(fpr[i], tpr[i], label=f'{label} (AUC={roc_auc[i]:.3f})', color=subtype_colors[label])
plt.plot([0,1],[0,1],'k--',linewidth=1,label="Random Guess")

plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve for Multi-Class Model")
plt.legend(loc="lower right")
plt.grid(alpha=0.3)
plt.show() 
Epoch 1/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 14ms/step - accuracy: 0.6594 - f1_score: 0.5168 - loss: 1.2764 - val_accuracy: 0.8188 - val_f1_score: 0.7718 - val_loss: 1.0477
Epoch 2/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9141 - f1_score: 0.8652 - loss: 0.3273 - val_accuracy: 0.8500 - val_f1_score: 0.8039 - val_loss: 0.9441
Epoch 3/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9204 - f1_score: 0.8925 - loss: 0.3254 - val_accuracy: 0.8750 - val_f1_score: 0.8237 - val_loss: 0.8913
Epoch 4/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9539 - f1_score: 0.9263 - loss: 0.1654 - val_accuracy: 0.8562 - val_f1_score: 0.8205 - val_loss: 0.9565
Epoch 5/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9728 - f1_score: 0.9385 - loss: 0.1357 - val_accuracy: 0.8875 - val_f1_score: 0.8872 - val_loss: 0.7037
Epoch 6/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9857 - f1_score: 0.9630 - loss: 0.0311 - val_accuracy: 0.8687 - val_f1_score: 0.8256 - val_loss: 1.0073
Epoch 7/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9763 - f1_score: 0.9430 - loss: 0.0878 - val_accuracy: 0.8562 - val_f1_score: 0.8113 - val_loss: 1.1788
Epoch 8/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9830 - f1_score: 0.9249 - loss: 0.0332 - val_accuracy: 0.8562 - val_f1_score: 0.8024 - val_loss: 1.1647
Epoch 9/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9954 - f1_score: 0.9754 - loss: 0.0370 - val_accuracy: 0.8500 - val_f1_score: 0.7704 - val_loss: 1.3605
Epoch 10/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9879 - f1_score: 0.9645 - loss: 0.0407 - val_accuracy: 0.8188 - val_f1_score: 0.7029 - val_loss: 1.4665
Epoch 11/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9867 - f1_score: 0.9601 - loss: 0.1005 - val_accuracy: 0.8687 - val_f1_score: 0.8163 - val_loss: 1.1282
Epoch 12/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9858 - f1_score: 0.9593 - loss: 0.0400 - val_accuracy: 0.8625 - val_f1_score: 0.8106 - val_loss: 1.1033
Epoch 13/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9984 - f1_score: 0.9787 - loss: 0.0084 - val_accuracy: 0.8750 - val_f1_score: 0.8372 - val_loss: 1.0305
Epoch 14/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9918 - f1_score: 0.9576 - loss: 0.0586 - val_accuracy: 0.8500 - val_f1_score: 0.7642 - val_loss: 1.2082
Epoch 15/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9895 - f1_score: 0.9422 - loss: 0.0488 - val_accuracy: 0.8625 - val_f1_score: 0.7903 - val_loss: 1.0896
Epoch 16/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9928 - f1_score: 0.9714 - loss: 0.0352 - val_accuracy: 0.8813 - val_f1_score: 0.8133 - val_loss: 1.1337
Epoch 17/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9923 - f1_score: 0.9688 - loss: 0.0171 - val_accuracy: 0.8625 - val_f1_score: 0.8102 - val_loss: 1.1552
Epoch 18/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9873 - f1_score: 0.9627 - loss: 0.0809 - val_accuracy: 0.8562 - val_f1_score: 0.7671 - val_loss: 1.1276
Epoch 19/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9826 - f1_score: 0.9608 - loss: 0.1708 - val_accuracy: 0.8687 - val_f1_score: 0.8596 - val_loss: 1.3592
Epoch 20/20
20/20 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9848 - f1_score: 0.9595 - loss: 0.0618 - val_accuracy: 0.8813 - val_f1_score: 0.8228 - val_loss: 1.5814
Test Accuracy: 0.879, Test F1-Score: 0.816
7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step 
No description has been provided for this image

4. Improve Model to a Balanced Multimodal Neural Network Training¶

In [11]:
def balance_data(X, y):
    return SMOTETomek(random_state=1234).fit_resample(X, y)

def train_weighted_nn(X_train, y_train, optimizer="adam", dropout_rate=0.3, hidden_units=128, epochs=30):
    class_weights = dict(enumerate(compute_class_weight(class_weight="balanced", classes=np.unique(y_train), y=y_train)))
    model = Sequential([
        Dense(hidden_units, activation="relu", input_shape=(X_train.shape[1],)), Dropout(dropout_rate),
        Dense(hidden_units, activation="relu"), Dense(len(np.unique(y_train)), activation="softmax")])
    model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    lr_scheduler = ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5, verbose=1)
    model.fit(X_train, y_train, batch_size=32, epochs=epochs, verbose=1, validation_split=0.2, class_weight=class_weights, callbacks=[lr_scheduler])
    return model

def evaluate_model(model, X_test, y_test, class_names, colors):
    from sklearn.metrics import f1_score as sklearn_f1_score

    test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
    y_pred_proba = model.predict(X_test)
    y_pred_labels = np.argmax(y_pred_proba, axis=1)
    f1 = sklearn_f1_score(y_test, y_pred_labels, average="weighted")
    print(f"\nTest Accuracy: {test_accuracy:.4f}, F1 Score: {f1:.4f}")
    
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred_labels, target_names=class_names))
    
    # confusion matrix
    cm = confusion_matrix(y_test, y_pred_labels)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix'), plt.xlabel('Predicted Labels'), plt.ylabel('True Labels'), plt.show()

    # roc curve
    y_test_binarized = label_binarize(y_test, classes=np.arange(len(class_names)))
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(len(class_names)):
        fpr[i], tpr[i], _ = roc_curve(y_test_binarized[:, i], y_pred_proba[:, i])
        roc_auc[i] = roc_auc_score(y_test_binarized[:, i], y_pred_proba[:, i])

    plt.figure(figsize=(6, 5))
    for i, class_name in enumerate(class_names):
        plt.plot(fpr[i], tpr[i], label=f'{class_name} (AUC = {roc_auc[i]:.3f})', color=colors[class_name])
    plt.plot([0, 1], [0, 1], 'k--', linewidth=1, label="Random Guess")
    plt.xlabel("False Positive Rate"), plt.ylabel("True Positive Rate"), plt.title("ROC Curve for Multi-Class Model")
    plt.legend(loc="lower right"), plt.grid(alpha=0.3), plt.show()

X_balanced, y_balanced = balance_data(X, y)
X_train, X_test, y_train, y_test = train_test_split(X_balanced, y_balanced, test_size=0.2, random_state=42)

model = train_weighted_nn(X_train, y_train)
evaluate_model(model, X_test, y_test, ['LumA', 'LumB', 'Basal', 'Her2'], subtype_colors)
Epoch 1/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - accuracy: 0.7436 - loss: 1.2957 - val_accuracy: 0.9155 - val_loss: 0.9389 - learning_rate: 0.0010
Epoch 2/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9162 - loss: 0.5372 - val_accuracy: 0.9329 - val_loss: 0.3616 - learning_rate: 0.0010
Epoch 3/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9478 - loss: 0.2965 - val_accuracy: 0.9271 - val_loss: 0.4012 - learning_rate: 0.0010
Epoch 4/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9683 - loss: 0.1596 - val_accuracy: 0.9475 - val_loss: 0.3184 - learning_rate: 0.0010
Epoch 5/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9645 - loss: 0.2702 - val_accuracy: 0.9329 - val_loss: 0.8069 - learning_rate: 0.0010
Epoch 6/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9648 - loss: 0.4444 - val_accuracy: 0.9592 - val_loss: 0.6680 - learning_rate: 0.0010
Epoch 7/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9736 - loss: 0.2701 - val_accuracy: 0.9563 - val_loss: 0.8974 - learning_rate: 0.0010
Epoch 8/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9776 - loss: 0.3327 - val_accuracy: 0.9592 - val_loss: 0.2596 - learning_rate: 0.0010
Epoch 9/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9821 - loss: 0.1119 - val_accuracy: 0.9650 - val_loss: 0.6723 - learning_rate: 0.0010
Epoch 10/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9937 - loss: 0.0463 - val_accuracy: 0.9563 - val_loss: 0.4833 - learning_rate: 0.0010
Epoch 11/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9863 - loss: 0.1311 - val_accuracy: 0.9534 - val_loss: 0.5132 - learning_rate: 0.0010
Epoch 12/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9870 - loss: 0.1036 - val_accuracy: 0.9650 - val_loss: 0.7490 - learning_rate: 0.0010
Epoch 13/30
41/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9961 - loss: 0.0125     
Epoch 13: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9961 - loss: 0.0125 - val_accuracy: 0.9592 - val_loss: 0.6229 - learning_rate: 0.0010
Epoch 14/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9942 - loss: 0.0188 - val_accuracy: 0.9621 - val_loss: 0.4736 - learning_rate: 5.0000e-04
Epoch 15/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9923 - loss: 0.0639 - val_accuracy: 0.9592 - val_loss: 0.5425 - learning_rate: 5.0000e-04
Epoch 16/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9969 - loss: 0.0171 - val_accuracy: 0.9621 - val_loss: 0.5996 - learning_rate: 5.0000e-04
Epoch 17/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9956 - loss: 0.0213 - val_accuracy: 0.9621 - val_loss: 0.4985 - learning_rate: 5.0000e-04
Epoch 18/30
41/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9957 - loss: 0.0265 
Epoch 18: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9958 - loss: 0.0269 - val_accuracy: 0.9534 - val_loss: 0.3806 - learning_rate: 5.0000e-04
Epoch 19/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9896 - loss: 0.0607 - val_accuracy: 0.9650 - val_loss: 0.4274 - learning_rate: 2.5000e-04
Epoch 20/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9981 - loss: 0.0112 - val_accuracy: 0.9650 - val_loss: 0.3677 - learning_rate: 2.5000e-04
Epoch 21/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 1.0000 - loss: 9.3295e-05 - val_accuracy: 0.9679 - val_loss: 0.3432 - learning_rate: 2.5000e-04
Epoch 22/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9976 - loss: 0.0126 - val_accuracy: 0.9650 - val_loss: 0.2957 - learning_rate: 2.5000e-04
Epoch 23/30
41/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9983 - loss: 0.0017     
Epoch 23: ReduceLROnPlateau reducing learning rate to 0.0001250000059371814.
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9983 - loss: 0.0018 - val_accuracy: 0.9679 - val_loss: 0.2973 - learning_rate: 2.5000e-04
Epoch 24/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9996 - loss: 0.0017 - val_accuracy: 0.9679 - val_loss: 0.3103 - learning_rate: 1.2500e-04
Epoch 25/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9994 - loss: 0.0014 - val_accuracy: 0.9679 - val_loss: 0.3207 - learning_rate: 1.2500e-04
Epoch 26/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 1.0000 - loss: 3.0701e-04 - val_accuracy: 0.9679 - val_loss: 0.3239 - learning_rate: 1.2500e-04
Epoch 27/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9998 - loss: 8.9701e-04 - val_accuracy: 0.9679 - val_loss: 0.3135 - learning_rate: 1.2500e-04
Epoch 28/30
41/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9987 - loss: 0.0061     
Epoch 28: ReduceLROnPlateau reducing learning rate to 6.25000029685907e-05.
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9987 - loss: 0.0067 - val_accuracy: 0.9679 - val_loss: 0.3119 - learning_rate: 1.2500e-04
Epoch 29/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9999 - loss: 6.6468e-04 - val_accuracy: 0.9679 - val_loss: 0.3102 - learning_rate: 6.2500e-05
Epoch 30/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.9999 - loss: 0.0033 - val_accuracy: 0.9679 - val_loss: 0.3077 - learning_rate: 6.2500e-05
14/14 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step 

Test Accuracy: 0.9720, F1 Score: 0.9718

Classification Report:
              precision    recall  f1-score   support

        LumA       0.98      1.00      0.99       106
        LumB       0.97      1.00      0.99       105
       Basal       0.97      0.94      0.96       114
        Her2       0.96      0.95      0.96       103

    accuracy                           0.97       428
   macro avg       0.97      0.97      0.97       428
weighted avg       0.97      0.97      0.97       428

No description has been provided for this image
No description has been provided for this image

5. Pathway Enrichment Analysis on Subtype-Specific Model Features¶

  • Identify biologically significant pathways associated with subtype classifications.
In [12]:
output_weights = model.get_weights()[-2].T 
input_weights = model.get_weights()[0]  
feature_contributions = np.dot(input_weights, output_weights.T)  
feature_per_subtype_df = pd.DataFrame(feature_contributions, index=X.columns, columns=subtype_colors.keys())
sorted_feature_per_subtype_df = feature_per_subtype_df.apply(lambda x: x.abs().sort_values(ascending=False))

cleaned_df = sorted_feature_per_subtype_df[~sorted_feature_per_subtype_df.index.str.contains("Dim")]
cleaned_df.index = [g.split(":")[3] if ":" in g else g for g in cleaned_df.index]

top_n = 30
for subtype in subtype_colors.keys():
    df = cleaned_df[[subtype]].reset_index().rename(columns={"index": "Feature", subtype: "Importance"}).sort_values(by="Importance", ascending=False)
    df.to_csv(f"../data/subtype_features/{subtype}_feature_contributions.csv", index=False)
    print(f"{subtype}, top {top_n} Features:\n{','.join(df['Feature'].head(top_n))}\n")
LumA, top 30 Features:
MLLT6,ELOVL3,EGOT,PDE4D,SPINK4,RGS7,ACVR1B,GANAB,NDRG4,RGPD1,LBP,ZER1,ABCA5,CHRNB3,GNAT2,MAGEH1,C11orf82,GLDC,C18orf1,SCD,PDIA3P,LHFPL5,CHPT1,C11orf24,PRTG,FSTL3,C5orf20,SLC15A2,WBP2NL,TBC1D3

LumB, top 30 Features:
NGFR,HOXD9,TFAP2B,ODF3L2,C1RL,TNMD,FAM83E,ACCN4,NRIP2,SYT17,AGMAT,HTATIP2,FAM132A,CCNB1,ZNF610,GABRR3,PTTG1,SLC7A3,LOC100131691,FBXO39,F10,PTH,TAC1,PDSS1,POU5F1B,LOC441455,PRHOXNB,FAM72B,TCHH,CHRAC1

Basal, top 30 Features:
PDE4C,SURF1,OSGIN1,OLFM2,KRT6B,SPRR1B,RAP1GAP2,AGR3,TDO2,FAM83C,RASSF10,PPM1J,RHBDL1,TSIX,TSPAN11,STOX2,LRP6,UFD1L,DEFA6,FAM106A,IL20RB,KLK11,LHX8,SV2C,S100A8,ZNF556,CNDP2,MYBPHL,PPY2,IL1A

Her2, top 30 Features:
NR2C1,ZKSCAN2,PPFIBP2,FOXP4,CYFIP1,GDAP1L1,EFNA1,POLR3C,SLC9A4,KRT2,TMEM88,PDE4D,PART1,ZMYM6,NCALD,SFRS6,SUZ12P,C14orf86,CLMN,OR52B6,CALM1,UTP3,SLC32A1,SYT2,CSE1L,FAM157B,DYNLT3,UBE4A,FAM76B,PREX1

In [13]:
from IPython.display import Image, display
listOfImageNames = ['../img/graph.png',
                    '../img/foam.jpg']
for imageName in listOfImageNames:
    display(Image(filename=imageName))
No description has been provided for this image
No description has been provided for this image

6. Applying GDC Trained Model to METABRIC for Cross-Dataset Subtype Prediction (short demo)¶

6.1.a Harmonizing mRNA Expression Distributions via Z-Score Normalization for Cross-Dataset Subtype Prediction¶

In this section, we scale and standardize mRNA expression data from the GDC scRNA-Seq dataset (which includes breast cancer subtype labels) and the METABRIC dataset (without labels). Since the datasets are already normalized (e.g., via log-transformation), we apply Z-Score normalization to align their distributions. This supports downstream analyses and building neural network models that can reliably predict subtypes across differing data resources.

In [14]:
print("Loading and preprocessing METABRIC data...")
metabric_df = pd.read_csv("../data/data_mrna_agilent_microarray.txt", sep="\t", index_col=0)
metabric_df = metabric_df.drop("Entrez_Gene_Id", axis=1).transpose().dropna(axis=1)

# Group by gene names to handle duplicates
metabric_df = metabric_df.T.groupby(level=0).mean().T
genes_intersection = list(set(df_gdc.columns[2:]).intersection(set(metabric_df.columns)))
print(f"Common genes N = {len(genes_intersection)}")

df_gdc = df_gdc[['BRCA', 'Labels'] + genes_intersection]
metabric_df = metabric_df[genes_intersection]

# EDA: trial and error ----------------------------------------------
# df_gdc_log = np.log2(df_gdc[genes_intersection] + 1)
# metabric_df_log = np.log2(metabric_df + 1)
# quantile_transformer = QuantileTransformer(output_distribution='uniform', random_state=1234)
# metabric_quantile_normalized = quantile_transformer.fit_transform(metabric_df)
# metabric_quantile_normalized_df = pd.DataFrame(metabric_quantile_normalized, columns=genes_intersection, index=metabric_df.index)

combined_data = pd.concat([df_gdc[genes_intersection], metabric_df[genes_intersection]])
# combined_data = pd.concat([df_gdc_log, metabric_df_log])

print("Fitting scaling...\nData distribution summary: \n")
scaler = StandardScaler()
# scaler = MinMaxScaler()
scaler.fit(combined_data)

# transform GDC and METABRIC data 
# df_gdc_scaled = pd.DataFrame(scaler.transform(df_gdc_log), columns=genes_intersection, index=df_gdc.index)
# metabric_scaled = pd.DataFrame(scaler.transform(metabric_df_log), columns=genes_intersection, index=metabric_df.index)
df_gdc_scaled = pd.DataFrame(scaler.transform(df_gdc[genes_intersection]), columns=genes_intersection, index=df_gdc.index)
# metabric_scaled = pd.DataFrame(scaler.transform(metabric_df[genes_intersection]), columns=genes_intersection, index=metabric_df.index)
metabric_scaled = pd.DataFrame(scaler.fit_transform(metabric_df[genes_intersection]), columns=genes_intersection, index=metabric_df.index)

df_gdc_scaled['Labels'] = df_gdc['Labels']
df_gdc_scaled['BRCA'] = df_gdc['BRCA']

gdc_mean = df_gdc_scaled[genes_intersection].mean().mean()
metabric_mean = metabric_scaled.mean().mean()
gdc_sd = df_gdc_scaled[genes_intersection].std().mean()
metabric_sd = metabric_scaled.std().mean()

print(f"GDC scaled data mean: {gdc_mean:.2f}; sd: {gdc_sd:.2f}")
print(f"METABRIC scaled data mean: {metabric_mean:.2f}; sd: {metabric_sd:.2f}")
Loading and preprocessing METABRIC data...
Common genes N = 17129
Fitting scaling...
Data distribution summary: 

GDC scaled data mean: 0.62; sd: 1.10
METABRIC scaled data mean: -0.00; sd: 1.00
In [15]:
print("EDA - Evaluating data distributions with PCA....")
pca_gdc, pca_metabric = PCA(n_components=2), PCA(n_components=2)

df_gdc_pca = df_gdc_scaled.copy()
df_gdc_pca[['PCA1', 'PCA2']] = pca_gdc.fit_transform(df_gdc_scaled.drop(columns=['Labels', 'BRCA']))

metabric_pca = metabric_scaled.copy()
metabric_pca[['PCA1', 'PCA2']] = pca_metabric.fit_transform(metabric_scaled.drop(columns=['BRCA'], errors='ignore'))

plt.figure(figsize=(8, 6))
sns.scatterplot(data=df_gdc_pca, x='PCA1', y='PCA2', hue='Labels', palette=subtype_colors, s=80, alpha=0.8)
plt.title("PCA - Scaled GDC Data", fontsize=12)
plt.xlabel("PC 1")
plt.ylabel("PC 2")
plt.legend(title='Subtype', fontsize=8)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()


plt.figure(figsize=(8, 6))
sns.scatterplot(data=metabric_pca, x='PCA1', y='PCA2', palette="viridis", s=80, alpha=0.8)
plt.title("PCA - Scaled METABRIC Data", fontsize=12)
plt.xlabel("PC 1")
plt.ylabel("PC 2")
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()
EDA - Evaluating data distributions with PCA....
No description has been provided for this image
No description has been provided for this image
In [16]:
# subset w common genes between programs 
genes_intersection = list(set(df_gdc_scaled.columns).intersection(set(metabric_scaled.columns)))
X_gdc_scaled = df_gdc_scaled[genes_intersection]
X_metabric_scaled = metabric_scaled[genes_intersection]

input_dim, encoding_dim = len(genes_intersection), 16
autoencoder = Autoencoder(input_dim, encoding_dim, deep=True)
criterion, optimizer = nn.MSELoss(), optim.Adam(autoencoder.parameters(), lr=0.001)
X_tensor = torch.tensor(X_gdc_scaled.values, dtype=torch.float32)

num_epochs = 100
for epoch in range(num_epochs):
    optimizer.zero_grad()
    loss = criterion(autoencoder(X_tensor), X_tensor)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

embeddings_gdc = autoencoder.encoder(X_tensor).detach().numpy()
embedding_cols = [f"Dim{i+1}" for i in range(encoding_dim)]
df_gdc_scaled[embedding_cols] = embeddings_gdc
subtype_embeddings = df_gdc_scaled.groupby("Labels")[embedding_cols].mean()

umap_model = umap.UMAP(n_components=2, random_state=42)
embeddings_umap = umap_model.fit_transform(embeddings_gdc)
df_gdc_scaled['UMAP1'], df_gdc_scaled['UMAP2'] = embeddings_umap[:, 0], embeddings_umap[:, 1]

plt.figure(figsize=(8, 6))
sns.scatterplot(data=df_gdc_scaled, x="UMAP1", y="UMAP2", hue="Labels", palette=subtype_colors)
plt.title("UMAP of GDC BRCA Subtype-Level Embeddings")
plt.xlabel("UMAP 1"), plt.ylabel("UMAP 2")
plt.legend(), plt.grid(True), plt.show()

X_train = df_gdc_scaled.drop(columns=["Labels", "BRCA"])
y_train = df_gdc_scaled["Labels"].astype('category').cat.codes

# handle class imbalance 
X_train_balanced, y_train_balanced = SMOTETomek(random_state=1234).fit_resample(X_train, y_train)
X_train_final, X_test_final, y_train_final, y_test_final = train_test_split(X_train_balanced, y_train_balanced, test_size=0.2, random_state=42)

class_weights = dict(enumerate(compute_class_weight(class_weight="balanced", classes=np.unique(y_train_final), y=y_train_final)))

def w_multiclass_model(optimizer="adam", dropout_rate=0.3, hidden_units=128):
    model = Sequential([
        Dense(hidden_units, activation="relu", input_shape=(X_train_final.shape[1],)), 
        Dropout(dropout_rate),
        Dense(hidden_units, activation="relu"),
        Dense(len(np.unique(y_train_final)), activation="softmax")
    ])
    model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=['accuracy'])
    return model

# train model
model = w_multiclass_model()
history = model.fit(X_train_final, y_train_final, batch_size=32, epochs=30, verbose=1, validation_split=0.2,
                    class_weight=class_weights, callbacks=[ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5, verbose=1)])

# evaluate model
test_loss, test_accuracy = model.evaluate(X_test_final, y_test_final, verbose=0)
print(f"\nTest Accuracy: {test_accuracy:.4f}")

y_pred_final = model.predict(X_test_final)
y_pred_labels = np.argmax(y_pred_final, axis=1)
print("\nClassification Report:")
print(classification_report(y_test_final, y_pred_labels, target_names=['LumA', 'LumB', 'Basal', 'Her2']))

plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix(y_test_final, y_pred_labels), annot=True, fmt='d', cmap='Blues',
            xticklabels=['LumA', 'LumB', 'Basal', 'Her2'], yticklabels=['LumA', 'LumB', 'Basal', 'Her2'])
plt.title('Confusion Matrix')
plt.xlabel('Predicted Labels'), plt.ylabel('True Labels')
plt.show()

y_test_binarized = label_binarize(y_test_final, classes=np.arange(len(['LumA', 'LumB', 'Basal', 'Her2'])))
y_pred_proba = y_pred_final

fpr, tpr, roc_auc = {}, {}, {}
for i, label in enumerate(['LumA', 'LumB', 'Basal', 'Her2']):
    fpr[i], tpr[i], _ = roc_curve(y_test_binarized[:, i], y_pred_proba[:, i])
    roc_auc[i] = roc_auc_score(y_test_binarized[:, i], y_pred_proba[:, i])

plt.figure(figsize=(8, 6))
for i, label in enumerate(['LumA', 'LumB', 'Basal', 'Her2']):
    plt.plot(fpr[i], tpr[i], label=f'{label} (AUC={roc_auc[i]:.3f})', color=subtype_colors[label])
plt.plot([0, 1], [0, 1], 'k--', linewidth=1, label="Random Guess")
plt.xlabel("False Positive Rate"), plt.ylabel("True Positive Rate")
plt.title("ROC Curve for Multi-Class Model")
plt.legend(loc="lower right"), plt.grid(alpha=0.3), plt.show()
Epoch [10/100], Loss: 1.6164
Epoch [20/100], Loss: 1.6136
Epoch [30/100], Loss: 1.6049
Epoch [40/100], Loss: 1.6002
Epoch [50/100], Loss: 1.5950
Epoch [60/100], Loss: 1.5886
Epoch [70/100], Loss: 1.5798
Epoch [80/100], Loss: 1.5682
Epoch [90/100], Loss: 1.5549
Epoch [100/100], Loss: 1.5288
No description has been provided for this image
Epoch 1/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - accuracy: 0.4992 - loss: 6.6012 - val_accuracy: 0.9242 - val_loss: 0.2100 - learning_rate: 0.0010
Epoch 2/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8272 - loss: 0.4747 - val_accuracy: 0.9300 - val_loss: 0.2104 - learning_rate: 0.0010
Epoch 3/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8917 - loss: 0.2863 - val_accuracy: 0.9359 - val_loss: 0.2142 - learning_rate: 0.0010
Epoch 4/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9153 - loss: 0.1838 - val_accuracy: 0.9708 - val_loss: 0.1216 - learning_rate: 0.0010
Epoch 5/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9458 - loss: 0.1342 - val_accuracy: 0.9417 - val_loss: 0.1546 - learning_rate: 0.0010
Epoch 6/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9678 - loss: 0.1018 - val_accuracy: 0.9446 - val_loss: 0.1558 - learning_rate: 0.0010
Epoch 7/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9814 - loss: 0.0706 - val_accuracy: 0.9738 - val_loss: 0.1010 - learning_rate: 0.0010
Epoch 8/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9764 - loss: 0.0668 - val_accuracy: 0.9650 - val_loss: 0.0856 - learning_rate: 0.0010
Epoch 9/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9802 - loss: 0.0603 - val_accuracy: 0.9563 - val_loss: 0.1204 - learning_rate: 0.0010
Epoch 10/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9767 - loss: 0.0606 - val_accuracy: 0.9796 - val_loss: 0.0685 - learning_rate: 0.0010
Epoch 11/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9668 - loss: 0.0925 - val_accuracy: 0.9708 - val_loss: 0.0872 - learning_rate: 0.0010
Epoch 12/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9557 - loss: 0.1656 - val_accuracy: 0.9563 - val_loss: 0.1581 - learning_rate: 0.0010
Epoch 13/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9620 - loss: 0.1041 - val_accuracy: 0.9679 - val_loss: 0.0627 - learning_rate: 0.0010
Epoch 14/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9576 - loss: 0.0723 - val_accuracy: 0.9679 - val_loss: 0.1013 - learning_rate: 0.0010
Epoch 15/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9720 - loss: 0.0809 - val_accuracy: 0.9738 - val_loss: 0.0775 - learning_rate: 0.0010
Epoch 16/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9712 - loss: 0.0884 - val_accuracy: 0.9738 - val_loss: 0.0736 - learning_rate: 0.0010
Epoch 17/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9717 - loss: 0.0702 - val_accuracy: 0.9679 - val_loss: 0.1302 - learning_rate: 0.0010
Epoch 18/30
37/43 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.9778 - loss: 0.0647 
Epoch 18: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9778 - loss: 0.0640 - val_accuracy: 0.9738 - val_loss: 0.0952 - learning_rate: 0.0010
Epoch 19/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9761 - loss: 0.0516 - val_accuracy: 0.9708 - val_loss: 0.0826 - learning_rate: 5.0000e-04
Epoch 20/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9850 - loss: 0.0475 - val_accuracy: 0.9708 - val_loss: 0.0678 - learning_rate: 5.0000e-04
Epoch 21/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9851 - loss: 0.0314 - val_accuracy: 0.9708 - val_loss: 0.0771 - learning_rate: 5.0000e-04
Epoch 22/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9817 - loss: 0.0509 - val_accuracy: 0.9708 - val_loss: 0.0736 - learning_rate: 5.0000e-04
Epoch 23/30
40/43 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.9816 - loss: 0.0332 
Epoch 23: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9817 - loss: 0.0334 - val_accuracy: 0.9767 - val_loss: 0.0751 - learning_rate: 5.0000e-04
Epoch 24/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9897 - loss: 0.0317 - val_accuracy: 0.9796 - val_loss: 0.0802 - learning_rate: 2.5000e-04
Epoch 25/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9857 - loss: 0.0277 - val_accuracy: 0.9650 - val_loss: 0.0868 - learning_rate: 2.5000e-04
Epoch 26/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9888 - loss: 0.0312 - val_accuracy: 0.9738 - val_loss: 0.0852 - learning_rate: 2.5000e-04
Epoch 27/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9929 - loss: 0.0179 - val_accuracy: 0.9767 - val_loss: 0.0818 - learning_rate: 2.5000e-04
Epoch 28/30
39/43 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.9988 - loss: 0.0106 
Epoch 28: ReduceLROnPlateau reducing learning rate to 0.0001250000059371814.
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9985 - loss: 0.0110 - val_accuracy: 0.9767 - val_loss: 0.0892 - learning_rate: 2.5000e-04
Epoch 29/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9959 - loss: 0.0092 - val_accuracy: 0.9767 - val_loss: 0.0929 - learning_rate: 1.2500e-04
Epoch 30/30
43/43 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9878 - loss: 0.0369 - val_accuracy: 0.9708 - val_loss: 0.0948 - learning_rate: 1.2500e-04

Test Accuracy: 0.9813
14/14 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step 

Classification Report:
              precision    recall  f1-score   support

        LumA       1.00      1.00      1.00       106
        LumB       0.97      1.00      0.99       105
       Basal       1.00      0.93      0.96       114
        Her2       0.95      1.00      0.98       103

    accuracy                           0.98       428
   macro avg       0.98      0.98      0.98       428
weighted avg       0.98      0.98      0.98       428

No description has been provided for this image
No description has been provided for this image
Out[16]:
(<matplotlib.legend.Legend at 0x377f9ecd0>, None, None)
In [17]:
genes_intersection = list(set(df_gdc_scaled.columns).intersection(set(metabric_scaled.columns)))
X_metabric_scaled = metabric_scaled[genes_intersection]

if X_metabric_scaled.shape[1] < 17147:
    dummy_columns = np.zeros((X_metabric_scaled.shape[0], 17147 - X_metabric_scaled.shape[1]))  
    X_metabric_scaled_final = np.hstack((X_metabric_scaled, dummy_columns))  
else:
    X_metabric_scaled_final = X_metabric_scaled 

y_metabric_pred_proba = model.predict(X_metabric_scaled_final)
y_metabric_pred = np.argmax(y_metabric_pred_proba, axis=1)

metabric_labels = {0: "LumA", 1: "LumB", 2: "Basal", 3: "Her2"}
y_metabric_pred_labels = np.array([metabric_labels[label] for label in y_metabric_pred])

print("\nBRCA Subtype Predictions for METABRIC data:")
for i, label in enumerate(y_metabric_pred_labels[:10]):
    print(f"Sample {i+1}: Predicted subtype - {label}")

umap_metabric = umap.UMAP(n_components=2, random_state=42)
X_umap_metabric = umap_metabric.fit_transform(X_metabric_scaled_final)

umap_df = pd.DataFrame(X_umap_metabric, columns=["UMAP1", "UMAP2"])
umap_df["Predicted Subtype"] = y_metabric_pred_labels

plt.figure(figsize=(8, 6))
for subtype, color in subtype_colors.items():
    mask = y_metabric_pred_labels == subtype
    plt.scatter(umap_df.loc[mask, "UMAP1"], umap_df.loc[mask, "UMAP2"], 
                color=color, label=subtype, alpha=0.6, s=60)

plt.title("UMAP - Predicted BRCA Subtypes of METABRIC Data", fontsize=12)
plt.xlabel("UMAP 1", fontsize=12)
plt.ylabel("UMAP 2", fontsize=12)
plt.legend(title="Predicted Subtypes")
plt.grid(True)
plt.show()

pca = PCA(n_components=3)
X_metabric_pca = pca.fit_transform(X_metabric_scaled_final)
pca_df = pd.DataFrame(X_metabric_pca, columns=["PC1", "PC2", "PC3"])
pca_df["Predicted Subtype"] = y_metabric_pred_labels

# fig = px.scatter_3d(pca_df,x="PC1",y="PC2",z="PC3",color="Predicted Subtype",color_discrete_map=subtype_colors,  
#     title="PCA - Predicted BRCA Subtypes of METABRIC Data",
#     labels={"PC1": "Principal Component 1", "PC2": "Principal Component 2", "PC3": "Principal Component 3"},
#     width=1200,height=900)
# fig.show()
60/60 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step  

BRCA Subtype Predictions for METABRIC data:
Sample 1: Predicted subtype - Her2
Sample 2: Predicted subtype - LumB
Sample 3: Predicted subtype - Her2
Sample 4: Predicted subtype - LumB
Sample 5: Predicted subtype - Basal
Sample 6: Predicted subtype - Basal
Sample 7: Predicted subtype - LumB
Sample 8: Predicted subtype - Her2
Sample 9: Predicted subtype - LumA
Sample 10: Predicted subtype - LumB
No description has been provided for this image
In [18]:
color_map = {
    "LumA": "#44AA99",  
    "LumB": "#FFC107",  
    "Basal": "#AA4499", 
    "Her2": "#449AE4"   
}

metabric_tensor = torch.tensor(metabric_scaled.values, dtype=torch.float32)
embeddings_metabric = autoencoder.encoder(metabric_tensor).detach().numpy()

n_clusters = 4
kmeans_metabric = KMeans(n_clusters=n_clusters, random_state=42)
predicted_clusters_metabric = kmeans_metabric.fit_predict(embeddings_metabric)
class_maps = {0: "LumA", 1: "LumB", 2: "Basal", 3: "Her2"}
cluster_centers = kmeans_metabric.cluster_centers_ # centroid of each cluster

# compute the centroids of the known subtypes - using mean embeddings of the subtypes
subtype_centroids = {}
for cluster_num in range(n_clusters):
    cluster_indices = np.where(predicted_clusters_metabric == cluster_num)[0]
    cluster_embeddings = embeddings_metabric[cluster_indices]
    mean_embedding = np.mean(cluster_embeddings, axis=0)
    subtype_centroids[cluster_num] = mean_embedding

# find the closest subtype centroid - based on cosine similarity
cluster_labels = []
for cluster_center in cluster_centers:
    distances = {label: cosine_similarity([cluster_center], [centroid])[0][0] for label, centroid in subtype_centroids.items()}
    closest_label = min(distances, key=distances.get)
    cluster_labels.append(closest_label)

# map predicted clusters to corresponding labels
predicted_labels_metabric = [cluster_labels[cluster] for cluster in predicted_clusters_metabric]

# UMAP visualization of metabric subtype prediction embeddings - based on k-mean clustering and cosine similarity assignements of subtypes to cluster
umap_metabric = umap.UMAP(n_components=2, random_state=42)
X_umap_metabric = umap_metabric.fit_transform(embeddings_metabric)

plt.figure(figsize=(10, 8))
label_seen = set()
for i, label in enumerate(predicted_labels_metabric):
    color = color_map[class_maps[predicted_clusters_metabric[i]]]
    if class_maps[predicted_clusters_metabric[i]] not in label_seen:
        label_seen.add(class_maps[predicted_clusters_metabric[i]])
        plt.scatter(X_umap_metabric[i, 0], X_umap_metabric[i, 1], 
                    color=color, label=f"{class_maps[predicted_clusters_metabric[i]]} Cluster", alpha=0.6)
    else:
        plt.scatter(X_umap_metabric[i, 0], X_umap_metabric[i, 1], 
                    color=color, alpha=0.6)

plt.title('UMAP - Predicted BRCA Subtypes of METABRIC Embeddings\n k-mean clustering and cosine similarity assignements ', fontsize=12)
plt.xlabel('UMAP 1', fontsize=12)
plt.ylabel('UMAP 2', fontsize=12)
plt.legend(title="Clusters and Subtypes")
plt.grid(True)
plt.show()
No description has been provided for this image
In [22]:
from sklearn.metrics import f1_score as sklearn_f1_score
n_clusters = 4
kmeans_metabric = KMeans(n_clusters=n_clusters, random_state=42)
predicted_clusters_metabric = kmeans_metabric.fit_predict(embeddings_metabric)
cluster_centers = kmeans_metabric.cluster_centers_
class_maps = {0: "LumA", 1: "LumB", 2: "Basal", 3: "Her2"}

subtype_centroids = {}
for i in range(n_clusters):
    idx = np.where(predicted_clusters_metabric == i)[0]
    subtype_centroids[i] = np.mean(embeddings_metabric[idx], axis=0)

cluster_labels = []
for center in cluster_centers:
    distances = {label: cosine_similarity([center], [centroid])[0][0]
                 for label, centroid in subtype_centroids.items()}
    cluster_labels.append(min(distances, key=distances.get))

predicted_labels_metabric = [cluster_labels[cluster] for cluster in predicted_clusters_metabric]
true_labels_metabric = [class_maps[label] for label in predicted_clusters_metabric]
X_train, X_test, y_train, y_test = train_test_split(
    embeddings_metabric, true_labels_metabric, test_size=0.2, random_state=42
)
rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
rf_model.fit(X_train, y_train)
y_pred = rf_model.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)
f1 = sklearn_f1_score(y_test, y_pred, average="weighted")

print(f"Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=['LumA', 'LumB', 'Basal', 'Her2']))

cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['LumA', 'LumB', 'Basal', 'Her2'],
            yticklabels=['LumA', 'LumB', 'Basal', 'Her2'])
plt.title('Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()
Accuracy: 0.9738, F1 Score: 0.9736

Classification Report:
              precision    recall  f1-score   support

        LumA       1.00      0.94      0.97        82
        LumB       1.00      0.91      0.95        54
       Basal       0.97      1.00      0.98       224
        Her2       0.88      1.00      0.93        21

    accuracy                           0.97       381
   macro avg       0.96      0.96      0.96       381
weighted avg       0.98      0.97      0.97       381

No description has been provided for this image
In [ ]: